from __future__ import print_function

import math
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt 
import os
from torch.utils.data import DataLoader, TensorDataset, random_split
import itertools
from functools import partial
import random
import pandas as pd
from torch import nn
import warnings
from sklearn.exceptions import ConvergenceWarning
from collections import OrderedDict, defaultdict
import torch.nn.functional as F

class TwoCropTransform:
    """Wrap a transform so each call returns two independently transformed views.

    The wrapped transform is applied twice to the same input; with a random
    augmentation pipeline this yields two different crops of one image.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]


class AverageMeter(object):
    """Keep the most recent value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) for every k in ``topk``.

    Each entry is a 1-element tensor holding the percentage of rows of
    ``output`` whose k highest-scoring columns contain the matching ``target``.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        _, top_idx = output.topk(max(topk), 1, True, True)
        top_idx = top_idx.t()  # (maxk, batch)
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


def adjust_learning_rate(args, optimizer, epoch):
    """Set every param group's learning rate for ``epoch``.

    Cosine mode anneals from ``args.learning_rate`` toward
    ``args.learning_rate * args.lr_decay_rate ** 3`` over ``args.epochs``.
    Step mode multiplies by ``args.lr_decay_rate`` once per milestone in
    ``args.lr_decay_epochs`` that ``epoch`` has strictly passed.
    """
    base_lr = args.learning_rate
    if not args.cosine:
        passed = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        new_lr = base_lr * (args.lr_decay_rate ** passed) if passed > 0 else base_lr
    else:
        floor_lr = base_lr * (args.lr_decay_rate ** 3)
        cosine_factor = 1 + math.cos(math.pi * epoch / args.epochs)
        new_lr = floor_lr + (base_lr - floor_lr) * cosine_factor / 2

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the LR from ``args.warmup_from`` to ``args.warmup_to``.

    Active only when ``args.warm`` is truthy and ``epoch <= args.warm_epochs``;
    otherwise the optimizer is left untouched. ``epoch`` is 1-based.
    """
    if not args.warm or epoch > args.warm_epochs:
        return

    batches_done = batch_id + (epoch - 1) * total_batches
    warmup_span = args.warm_epochs * total_batches
    progress = batches_done / warmup_span
    lr = args.warmup_from + progress * (args.warmup_to - args.warmup_from)

    for group in optimizer.param_groups:
        group['lr'] = lr


def set_optimizer(opt, model):
    """Build SGD over ``model.parameters()`` from ``opt``'s lr/momentum/weight decay."""
    hyperparams = dict(
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
    )
    return optim.SGD(model.parameters(), **hyperparams)


def save_model(model, optimizer, opt, epoch, save_file):
    """Write a checkpoint dict to ``save_file`` with ``torch.save``.

    Keys: ``'opt'`` (stored as given), ``'model'`` and ``'optimizer'``
    (their ``state_dict()``s) and ``'epoch'``.
    """
    print('==> Saving...')
    torch.save(
        {
            'opt': opt,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        },
        save_file,
    )

def spectral_filter_and_normalize(F, num_zero_high=0, num_zero_low=0):
    """
    Project centered features onto a filtered set of covariance eigen-directions.

    Steps:
      1. Center ``F`` column-wise and form the covariance ``C = F_c^T F_c / n``.
      2. Eigendecompose ``C`` (``torch.linalg.eigh``, ascending eigenvalues).
      3. Drop the ``num_zero_low`` smallest and ``num_zero_high`` largest eigenvalues.
      4. Every remaining strictly positive eigenvalue gets the same weight ``1/k``,
         where ``k`` is how many survive; all others get weight 0.
      5. Return ``F_c @ A`` with ``A = V diag(sqrt(w)) V^T``. This equals projecting
         ``F_c`` onto the kept eigen-directions and scaling by ``1/sqrt(k)``. It
         flattens the filter weights, not the spectrum of the output.

    Args:
        F (torch.Tensor): Feature matrix of shape (n, d).
        num_zero_high (int): Number of largest eigenvalues to discard.
        num_zero_low (int): Number of smallest eigenvalues to discard.

    Returns:
        torch.Tensor: Transformed features of shape (n, d), float32.
    """
    F = F.float()
    F_centered = F - F.mean(dim=0, keepdim=True)  # (n, d)

    # Covariance of the centered features (d x d).
    C = F_centered.T @ F_centered / F_centered.shape[0]

    # Symmetric eigendecomposition; eigenvalues come back in ascending order.
    eigvals, eigvecs = torch.linalg.eigh(C)

    eigvals_filtered = eigvals.clone()
    if num_zero_low > 0:
        eigvals_filtered[:num_zero_low] = 0
    if num_zero_high > 0:
        eigvals_filtered[-num_zero_high:] = 0

    # Survivors (strictly positive) share a uniform weight 1/k. Everything else is
    # set to exactly 0. That includes filtered-out entries and zero or slightly
    # negative eigenvalues produced by round-off in eigh, which would otherwise
    # reach sqrt() and turn into NaN.
    keep = eigvals_filtered > 0
    num_kept = int(keep.sum())
    weights = torch.zeros_like(eigvals_filtered)
    if num_kept > 0:
        weights[keep] = 1.0 / num_kept

    A = eigvecs @ torch.diag(torch.sqrt(weights)) @ eigvecs.T
    return F_centered @ A

def truncate_by_singular_values(F, num_singular_values):
    """
    Low-rank reconstruction of the column-centered features.

    ``F`` is cast to float32 and centered per column. A thin SVD is taken, and
    singular values from index ``num_singular_values`` onward are zeroed (Python
    slice semantics). The matrix is then rebuilt as ``U diag(S_kept) Vh``.

    Args:
        F (torch.Tensor): Input feature matrix of shape (n, d).
        num_singular_values (int): Number of leading singular values to keep.

    Returns:
        torch.Tensor: Reconstructed (centered) matrix of shape (n, d).
    """
    centered = F.float()
    centered = centered - centered.mean(dim=0, keepdim=True)

    left, svals, right_h = torch.linalg.svd(centered, full_matrices=False)

    kept = svals.clone()
    kept[num_singular_values:] = 0

    return (left * kept.unsqueeze(0)) @ right_h

def plot_singular_values(features, epoch, filepath_template='SpecRegLoss/SingularValueSpectrum_{i}.png'):
    """
    Save a plot of the max-normalized singular value spectrum of a feature matrix.

    Args:
        features (torch.Tensor): Feature matrix of shape (n, d).
        epoch: Value substituted for ``{i}`` in ``filepath_template``.
        filepath_template (str): Output path template containing ``{i}``.
            Its parent directory is created if it does not exist.

    Raises:
        ValueError: If ``features`` is not a ``torch.Tensor``.
    """
    if not isinstance(features, torch.Tensor):
        raise ValueError("Features must be a torch.Tensor")

    # Only singular values are needed: a thin SVD avoids building large square
    # U/V factors, and no_grad keeps the decomposition out of autograd.
    with torch.no_grad():
        _, singular_values, _ = torch.linalg.svd(features.detach(), full_matrices=False)

    # Scale so the largest value is 1; skip scaling for an all-zero matrix.
    max_sv = singular_values.max()
    if max_sv > 0:
        singular_values = singular_values / max_sv

    singular_values = singular_values.cpu().numpy()
    ranks = range(1, len(singular_values) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(ranks, singular_values, marker='o', linestyle='-', color='b')
    plt.title('Singular Value Spectrum')
    plt.xlabel("Singular Value Rank")
    plt.ylabel("Singular Value")
    plt.grid(True)

    filepath = filepath_template.format(i=epoch)
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    plt.savefig(filepath)
    plt.close()


def plot_singular_values_labels_mine(gram_matrix, labels, iteration, save_path='SpecRegLoss/gram_eigenvalues_labels_plot'):
    """
    Save a plot of Gram-matrix eigenvalues and label projections onto eigenvectors.

    Eigenpairs are sorted by decreasing eigenvalue, and eigenvalues are divided by
    the largest one. The label vector is projected onto every eigenvector. The
    projections are then re-ordered by decreasing magnitude before plotting, so
    the red bars and the blue eigenvalue curve do not share an index ordering.

    Args:
        gram_matrix (torch.Tensor): Symmetric matrix of shape (n, n).
        labels (torch.Tensor): Label vector of shape (n,).
        iteration (int): Used in the output filename.
        save_path (str): Base path; writes ``f"{save_path}_epoch_{iteration}.png"``.
            The parent directory is created if needed.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(gram_matrix)

    # Sort eigenpairs in descending eigenvalue order.
    sorted_indices = torch.argsort(eigenvalues, descending=True)
    eigenvalues = eigenvalues[sorted_indices]
    eigenvectors = eigenvectors[:, sorted_indices]

    # Scale eigenvalues so the largest is 1.
    eigenvalues = eigenvalues / eigenvalues.max()

    # Entry i is <v_i, y>. Labels are cast to the eigenvectors' dtype/device so
    # float64 or GPU Gram matrices work.
    labels = labels.to(device=eigenvectors.device, dtype=eigenvectors.dtype)
    label_projections = eigenvectors.T @ labels
    label_projections = label_projections[torch.argsort(label_projections.abs(), descending=True)]

    # Matplotlib needs host-side arrays without autograd history.
    label_projections = label_projections.detach().cpu().numpy()
    eigenvalues = eigenvalues.detach().cpu().numpy()

    fig, ax_left = plt.subplots(figsize=(10, 6))

    # Label projections (left axis). The fixed [0, 10] range clips negative projections.
    ax_left.vlines(range(1, len(label_projections) + 1), 0, label_projections, color='red', alpha=0.7)
    ax_left.set_ylabel("Label Projections", color='red')
    ax_left.tick_params(axis='y', labelcolor='red')
    ax_left.set_ylim(0, 10)

    # Normalized eigenvalues (right axis).
    ax_right = ax_left.twinx()
    ax_right.plot(
        range(1, len(eigenvalues) + 1),
        eigenvalues,
        label="Normalized Eigenvalues",
        linewidth=2,
        color="blue"
    )
    ax_right.set_ylabel("Normalized Eigenvalues", color='blue')
    ax_right.tick_params(axis='y', labelcolor='blue')
    ax_right.set_ylim(0, 1.1)

    ax_left.set_xlabel("Eigenvalue Index")
    plt.title("Eigenvalues and Label Projections with Dual Y-Axes")

    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    filename = f"{save_path}_epoch_{iteration}.png"
    plt.savefig(filename)
    plt.close()


def plot_singular_values_labels(gram_matrix, labels, iteration, save_path='SpecRegLoss/gram_eigenvalues_labels_plot'):
    """
    Plot normalized Gram-matrix eigenvalues and the label vector's projections onto each eigenvector.

    Eigenvalues are shown in descending order, divided by the largest one (plus
    1e-12). Projections are sorted by decreasing magnitude, independently of the
    eigenvalue order.

    Args:
        gram_matrix (torch.Tensor): Symmetric matrix of shape (n, n).
        labels (torch.Tensor): Label vector of shape (n,).
        iteration (int): Used in the output filename.
        save_path (str): Base path; writes ``f"{save_path}_epoch_{iteration}.png"``.
            The parent directory is created if needed.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(gram_matrix)  # ascending order
    eigenvalues = eigenvalues.flip(0)                            # descending
    eigenvectors = eigenvectors.flip(1)                          # flip columns to match

    eigenvalues = eigenvalues / (eigenvalues.max() + 1e-12)

    # Cast labels to the eigenvectors' dtype/device (e.g. float64 or GPU Gram matrices).
    labels = labels.to(device=eigenvectors.device, dtype=eigenvectors.dtype)
    projections = torch.matmul(eigenvectors.T, labels)  # shape: (n,)
    projections_sorted = projections[torch.argsort(projections.abs(), descending=True)]

    # Matplotlib needs host-side values without autograd history.
    projections_sorted = projections_sorted.detach().cpu()
    eigenvalues = eigenvalues.detach().cpu()

    fig, ax_left = plt.subplots(figsize=(10, 6))

    # Label projections (left axis). The fixed [0, 10] range clips negative projections.
    ax_left.vlines(torch.arange(1, len(projections_sorted) + 1), 0, projections_sorted, color='red', alpha=0.7)
    ax_left.set_ylabel("Label Projections", color='red')
    ax_left.tick_params(axis='y', labelcolor='red')
    ax_left.set_ylim(0, 10)
    ax_left.set_xlabel("Index (sorted by |projection|)")

    # Eigenvalues (right axis).
    ax_right = ax_left.twinx()
    ax_right.plot(torch.arange(1, len(eigenvalues) + 1), eigenvalues, color='blue', linewidth=2)
    ax_right.set_ylabel("Normalized Eigenvalues", color='blue')
    ax_right.tick_params(axis='y', labelcolor='blue')
    ax_right.set_ylim(0, 1.1)

    plt.title("Eigenvalues and Label Projections (Sorted by Projection Magnitude)")
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    filename = f"{save_path}_epoch_{iteration}.png"
    plt.savefig(filename)
    plt.close()

# def compute_gram_matrix(X):
#     """
#     Compute the Gram matrix H_infinity for a dataset X based on the kernel associated with the ReLU function.

#     Parameters:
#         X (torch.Tensor): Data tensor of shape (n, d), where n is the number of samples and d is the feature dimension.

#     Returns:
#         torch.Tensor: Gram matrix H_infinity of shape (n, n).
#     """
    
#     if X.is_cuda:
#         X = X.cpu().detach()
#     else:
#         X = X.detach()

#     # Normalize the rows of X to unit vectors (compute cosine similarity)
#     X_norm = X / X.norm(dim=1, keepdim=True)

#     # Compute the pairwise dot products
#     cosine_similarity = torch.matmul(X_norm.T, X_norm)
    
#     # Clamp values to the range [-1, 1] to avoid numerical issues with arccos
#     cosine_similarity = torch.clamp(cosine_similarity, -1.0, 1.0)

#     # Compute the Gram matrix using the given formula
#     H_infinity = (cosine_similarity / torch.pi) - (torch.arccos(cosine_similarity) / (2 * torch.pi))

#     return H_infinity # Gram matrix has size d*d, not n*n

def compute_gram_matrix(X):
    """
    Return the feature Gram matrix ``X^T X`` of a data matrix.

    Despite the historical name, this does NOT compute the ReLU/NTK kernel
    H_infinity. It returns the raw, unnormalized inner products between the
    feature columns of ``X``.

    Parameters:
        X (torch.Tensor): Data tensor of shape (n, d).

    Returns:
        torch.Tensor: ``X.T @ X`` of shape (d, d), detached and on the CPU.
    """
    # Detached CPU copy: the result never joins autograd or holds GPU memory.
    # (.cpu() is a no-op for tensors already on the CPU.)
    X = X.detach().cpu()
    return X.T @ X

def get_entropy_energy_based_rank_old(W, opt):
    """
    Legacy spectral statistics of ``W`` flattened to (W.shape[0], -1).

    With ``S`` the singular values and ``p = S / sum(S)``:
      * entropy: ``-sum(p * log(p + 1e-6))`` as a 0-d tensor.
      * effective_rank: number of ``S / max(S)`` strictly above
        ``opt.rank_threshold`` (int).
      * energy_based_rank: number of leading positions whose cumulative share of
        ``p**2`` is <= ``opt.energy_threshold`` (int).
    """
    matrix = W.view(W.shape[0], -1)
    _, svals, _ = torch.linalg.svd(matrix)

    prob = svals / torch.sum(svals)
    above_threshold = torch.sum(svals / svals.max() > opt.rank_threshold).item()

    energy = prob ** 2
    energy_share = torch.cumsum(energy, dim=0) / energy.sum()
    energy_rank = torch.sum(energy_share <= opt.energy_threshold).item()

    entropy = -torch.sum(prob * torch.log(prob + 1e-6))
    return entropy, above_threshold, energy_rank

def get_entropy_energy_based_rank(W, opt=None):
    """
    Spectral entropy, effective rank and participation-ratio rank of a weight tensor.

    ``W`` is flattened to (W.shape[0], -1). With ``lam = S**2`` (squared singular
    values) and ``p = lam / sum(lam)``:
      * spectral_entropy  = -sum(p * log(p))
      * effective_rank    = exp(spectral_entropy)
      * energy_based_rank = sum(lam)**2 / sum(lam**2)   (participation ratio)
    Small epsilons (1e-12) guard the divisions and the log.

    Args:
        W (torch.Tensor): Weight tensor with at least one dimension; may be non-contiguous.
        opt: Unused; accepted for call-site compatibility.

    Returns:
        tuple[float, float, float]: (spectral_entropy, effective_rank, energy_based_rank).
    """
    # Diagnostics only, so no autograd graph is needed. reshape (unlike view) also
    # accepts non-contiguous tensors such as permuted or transposed weights.
    with torch.no_grad():
        reshaped_weights = W.detach().reshape(W.shape[0], -1)
        _, S, _ = torch.linalg.svd(reshaped_weights, full_matrices=False)

        squared_svals = S ** 2
        total_energy = torch.sum(squared_svals)
        sum_of_svals_sq = torch.sum(squared_svals ** 2)

        energy_based_rank = (total_energy ** 2) / (sum_of_svals_sq + 1e-12)

        p = squared_svals / (total_energy + 1e-12)
        spectral_entropy = -torch.sum(p * torch.log(p + 1e-12))
        effective_rank = torch.exp(spectral_entropy)

    return spectral_entropy.item(), effective_rank.item(), energy_based_rank.item()


def get_nuclear_norm_weight_loss(W, lambda_nuclear=0.1, rank_threshold=0.3, energy_threshold=0.80):
    """
    Nuclear-norm penalty of a weight tensor plus two diagnostic rank estimates.

    ``W`` is flattened to (W.shape[0], -1). The returned penalty
    ``lambda_nuclear * sum(S)`` stays in the autograd graph, so it can be added
    to a training loss.

    Args:
        W (torch.Tensor): Weight tensor.
        lambda_nuclear (float): Regularization coefficient.
        rank_threshold (float): Singular values strictly above this count toward ``train_rank``.
        energy_threshold (float): ``energy_based_rank`` counts leading components whose
            cumulative share of squared singular values is <= this value.

    Returns:
        tuple: (penalty tensor, train_rank int, energy_based_rank int).

    Side effects:
        Prints both rank estimates.
    """
    reshaped_weights = W.reshape(W.shape[0], -1)
    # Only singular values are used; the thin SVD is cheaper and differentiable w.r.t. S.
    _, S, _ = torch.linalg.svd(reshaped_weights, full_matrices=False)

    # Rank diagnostics do not need gradients.
    S_diag = S.detach()
    train_rank = torch.sum(S_diag > rank_threshold).item()

    squared_singular_values = S_diag ** 2
    cumulative_share = torch.cumsum(squared_singular_values, dim=0) / squared_singular_values.sum()
    energy_based_rank = torch.sum(cumulative_share <= energy_threshold).item()

    print('encoder rank with threshold is: ', train_rank)
    print('encoder rank with energy-based is: ', energy_based_rank)

    nuclear_loss = torch.sum(S)
    return lambda_nuclear * nuclear_loss, train_rank, energy_based_rank


class RepresentationTracker:
    """
    Track how encoder/projector outputs change between consecutive ``update`` calls.

    From the second call on, each call appends the per-dimension mean absolute
    difference between the current and previous (detached, CPU) outputs to
    ``encoder_deltas`` / ``projector_deltas``. Rows are compared positionally, so
    consecutive calls must pass the same number of rows.
    """

    def __init__(self, augmented_features=False, save_dir='rep_logs', track_projector=True):
        self.prev_encoder_out = None
        self.prev_projector_out = None
        self.encoder_deltas = []     # list of (D,) tensors
        self.projector_deltas = []   # list of (D',) tensors
        self.save_dir = save_dir
        self.track_projector = track_projector
        self.augmented_features = augmented_features
        os.makedirs(self.save_dir, exist_ok=True)

    def update(self, encoder_out, projector_out=None, step=None):
        """
        Record the current outputs and, after the first call, their drift.

        Args:
            encoder_out (Tensor): Encoder features. When ``projector_out`` is given,
                this is expected to hold ``n_views * bsz`` rows (views concatenated
                along dim 0).
            projector_out (Tensor or None): Projector features of shape (bsz, n_views, D').
                When None, ``encoder_out`` is tracked exactly as given.
            step (int or str): If not None, current tensors are saved to ``save_dir``.
        """
        if projector_out is not None:
            bsz, n_views, proj_dim = projector_out.shape
            if self.augmented_features:
                # Track every view. reshape(..., -1) keeps the encoder's own feature
                # width, which need not equal the projector's.
                encoder_out = encoder_out.reshape(bsz * n_views, -1)
                projector_out = projector_out.reshape(bsz * n_views, proj_dim)
            else:
                # Track only the first view of each sample.
                encoder_out, _ = torch.split(encoder_out, [bsz, bsz], dim=0)
                projector_out = projector_out[:, 0, :]

        enc = encoder_out.detach().cpu()
        proj = projector_out.detach().cpu() if projector_out is not None else None

        if self.prev_encoder_out is not None:
            enc_delta = torch.abs(enc - self.prev_encoder_out).mean(dim=0)  # (D,)
            self.encoder_deltas.append(enc_delta)

            if self.track_projector and proj is not None and self.prev_projector_out is not None:
                proj_delta = torch.abs(proj - self.prev_projector_out).mean(dim=0)  # (D',)
                self.projector_deltas.append(proj_delta)

        self.prev_encoder_out = enc
        if self.track_projector and proj is not None:
            self.prev_projector_out = proj

        # Optional saving for inspection/debugging.
        if step is not None:
            torch.save(enc, os.path.join(self.save_dir, f"encoder_step_{step}.pt"))
            if proj is not None:
                torch.save(proj, os.path.join(self.save_dir, f"projector_step_{step}.pt"))

    def log_to_wandb(self, step):
        """
        Log the most recent encoder and projector deltas (one scalar per dimension) to WandB.
        """
        import wandb
        if self.encoder_deltas:
            latest_enc_delta = self.encoder_deltas[-1].tolist()  # plain floats for wandb
            wandb.log({f"Encoder Drift/dim_{i}": val for i, val in enumerate(latest_enc_delta)}, step=step)

        if self.projector_deltas:
            latest_proj_delta = self.projector_deltas[-1].tolist()
            wandb.log({f"Projector Drift/dim_{i}": val for i, val in enumerate(latest_proj_delta)}, step=step)

    def get_drift_history(self):
        """Return ``(encoder_deltas, projector_deltas)`` lists."""
        return self.encoder_deltas, self.projector_deltas

# # utils/rep_drift.py
# import os, torch, matplotlib.pyplot as plt, seaborn as sns
# sns.set_style("white")

# class DriftTracker:
#     """
#     Collects batch-mean representations and visualises:
#       • velocity heat-map   (dims × steps)
#       • histogram of total drift per dimension
#     """

#     def __init__(self, latent_dim, save_dir="drift_vis"):
#         self.D          = latent_dim
#         self.save_dir   = save_dir
#         os.makedirs(save_dir, exist_ok=True)

#         self._history = []           # list of (D,) CPU tensors
#         self._prev    = None         # previous batch mean

#     # --------------------------------------------------
#     def update(self, rep_batch):
#         """Call every training step with (B,D) tensor -- encoder+projector output."""
#         # batch mean & detach to CPU
#         mean_vec = rep_batch.detach().mean(dim=0).cpu()   # (D,)
#         self._history.append(mean_vec)

#     # --------------------------------------------------
#     def _stack_history(self):
#         return torch.stack(self._history)     # (T, D)

#     # --------------------------------------------------
#     def save_velocity_heatmap(self, file_name="vel_heatmap.png"):
#         H = self._stack_history()             # (T,D)
#         vel = torch.abs(H[1:] - H[:-1]).T     # (D, T-1)

#         plt.figure(figsize=(12, 6))
#         sns.heatmap(
#             vel.numpy(),
#             cmap="magma",
#             cbar_kws={"label": "mean |Δf| per step"},
#             xticklabels=200, yticklabels=False
#         )
#         plt.xlabel("training step"); plt.ylabel("feature dim")
#         plt.title("Velocity heat-map")
#         path = os.path.join(self.save_dir, file_name)
#         plt.tight_layout(); plt.savefig(path); plt.close()
#         return path

#     # --------------------------------------------------
#     def save_drift_histogram(self, file_name="drift_hist.png"):
#         H = self._stack_history()             # (T,D)
#         total_drift = torch.abs(H[1:] - H[:-1]).sum(dim=0)  # (D,)

#         plt.figure(figsize=(6,4))
#         plt.hist(total_drift.numpy(), bins=40, color="steelblue")
#         plt.xlabel("Σ_t |Δf_j|"); plt.ylabel("# dims")
#         plt.title("Total drift per dimension")
#         path = os.path.join(self.save_dir, file_name)
#         plt.tight_layout(); plt.savefig(path); plt.close()
#         return path

@torch.no_grad()
def estimate_coeff(loader, attr_fn, model, opt, batches=None):
    """
    Monte-Carlo estimate of E[f(x) * g(x)] for encoder and projector features.

    For every sample, f(x) is the mean over feature dimensions of the
    L2-normalized encoder (resp. projector) output, and g(x) comes from
    ``attr_fn(images, labels, meta)``, e.g. a +/-1 spurious or core indicator.
    The model is put in eval mode for the estimate, and its previous train/eval
    mode is restored afterwards.

    Args:
        loader: Iterable of batches ``(image, labels, meta, ...)``. ``image`` is
            indexable by view (``image[0]``, ``image[1]``).
        attr_fn: Callable ``(images, labels, meta) -> tensor`` with one value per image.
        model: Module with ``encoder`` and ``head`` submodules; assumes
            ``model.encoder`` returns (B, D) features.
        opt: Needs ``augmented_features`` (use both views) and ``method``
            (only ``'SimCLR'`` is supported).
        batches: If not None, stop after this many mini-batches.

    Returns:
        tuple[float, float]: (encoder coefficient, projector coefficient).

    Raises:
        ValueError: If ``opt.method`` is unsupported or no samples were processed.
    """
    if opt.method != 'SimCLR':
        raise ValueError(f"estimate_coeff only supports opt.method == 'SimCLR', got {opt.method!r}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_encoder, num_projector, denom = 0.0, 0.0, 0

    was_training = model.training
    model.eval()
    try:
        for idx, data in enumerate(loader):
            if batches is not None and idx >= batches:
                break
            image = data[0]
            labels = data[1].to(device, non_blocking=True)
            meta = data[2].to(device)
            # Either both augmented views stacked along the batch dim, or just the first.
            if opt.augmented_features:
                images = torch.cat([image[0], image[1]], dim=0)
            else:
                images = image[0]
            images = images.to(device, non_blocking=True)

            z_encoder = model.encoder(images)
            # The head consumes the raw encoder features; both outputs are then
            # L2-normalized per sample before averaging over feature dimensions.
            z_projector = F.normalize(model.head(z_encoder), dim=1)
            z_encoder = F.normalize(z_encoder, dim=1)
            f_encoder_x = z_encoder.mean(dim=1)      # (B,)
            f_projector_x = z_projector.mean(dim=1)  # (B,)

            g_x = attr_fn(images, labels, meta).to(device).float()  # (B,)
            num_encoder += (f_encoder_x * g_x).sum().item()
            num_projector += (f_projector_x * g_x).sum().item()
            denom += images.size(0)
    finally:
        # Do not leave the caller's model stuck in eval mode (e.g. mid-training).
        model.train(was_training)

    if denom == 0:
        raise ValueError("estimate_coeff: no samples were processed (empty loader or batches == 0)")
    return num_encoder / denom, num_projector / denom

def repeat_if_needed(tensor, target_len):
    """
    Repeat rows (``repeat_interleave`` along dim 0) so ``tensor.shape[0] == target_len``.

    Works for 1-D (labels) or 2-D (metadata) tensors. Each row is repeated
    ``target_len // len(tensor)`` times in place, e.g. [a, b] -> [a, a, b, b].

    Raises:
        ValueError: If ``target_len`` is not a positive multiple of ``tensor.size(0)``.
            Previously this silently returned a tensor of the wrong length.
    """
    n_rows = tensor.size(0)
    if n_rows == target_len:
        return tensor  # already aligned
    if n_rows == 0 or target_len <= 0 or target_len % n_rows != 0:
        raise ValueError(
            f"cannot align {n_rows} rows to length {target_len}: "
            "target length must be a positive multiple of the row count"
        )
    return tensor.repeat_interleave(target_len // n_rows, dim=0)

def spur_attr(imgs, labels, meta):
    """Turn metadata column 0 into a per-image +/-1 indicator.

    ``meta`` is row-repeated to match ``imgs.size(0)`` (e.g. B -> 2B when both
    views are stacked). A {0, 1} bit maps to {-1, +1}. Per the original note,
    column 0 is the spurious-background bit in spur_cifar10 (1 = colour matches
    label).
    """
    aligned_meta = repeat_if_needed(meta, imgs.size(0))
    background_bit = aligned_meta[:, 0]
    return 2 * background_bit - 1

def core_attr(imgs, labels, meta, target_cls):
    """One-vs-all class probe: +1 where label == ``target_cls``, -1 elsewhere.

    ``labels`` is row-repeated to match ``imgs.size(0)`` first; ``meta`` is unused.
    """
    aligned_labels = repeat_if_needed(labels, imgs.size(0))
    is_target = aligned_labels == target_cls
    return 2 * is_target - 1

def avg_core_coeff(loader, model, attr_fn_core, opt, num_classes=10, device="cuda", batches=None):
    """
    Compute the 1-vs-all core coefficient for every class and return the means.

    attr_fn_core(imgs, labels, meta, target_cls=cls) must return a +/-1 tensor
        with one entry per image, indicating "label == cls ? +1 : -1".

    num_classes  – how many classes to iterate over (CIFAR-10 → 10); must be >= 1.
    device       – currently unused; kept for call-site compatibility.
    batches      – None ⇒ use entire loader; else use first `batches` mini-batches
                   (forwarded to ``estimate_coeff``).

    Returns:
        tuple[float, float]: Mean encoder and mean projector coefficients over classes.

    Raises:
        ValueError: If ``num_classes < 1``.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    coeffs_encoder, coeffs_projector = [], []
    for cls in range(num_classes):
        coeff_cls_encoder, coeff_cls_projector = estimate_coeff(
            loader,
            # The default argument pins this probe's class into the callback.
            lambda i, l, m, c=cls: attr_fn_core(i, l, m, target_cls=c),
            model,
            opt,
            batches=batches,
        )
        coeffs_encoder.append(coeff_cls_encoder)
        coeffs_projector.append(coeff_cls_projector)

    return sum(coeffs_encoder) / len(coeffs_encoder), sum(coeffs_projector) / len(coeffs_projector)


def prune_projection_head(model, amount=0.3, finalise=False):
    import torch.nn.utils.prune as prune
    for i, layer in enumerate(model.head):
        if isinstance(layer, torch.nn.Linear):
            importance = -layer.weight.detach().abs()
            prune.l1_unstructured(layer, name='weight', amount=amount, importance_scores=importance)
            if finalise:
                prune.remove(layer, "weight")

# ------------------------------------------------------------
# build once – valid for every 32×32 image in SpuriousCIFAR10_chatgpt
# ------------------------------------------------------------
def make_spur_mask(h=32, w=32, device="cpu"):
    """
    Build a (3, H, W) boolean mask whose True entries are the spurious pixels.

    The mask covers the middle row (``h // 2``) in channel 0 only, plus the middle
    column (``w // 2``) in all three channels.
    """
    mask = torch.zeros(3, h, w, dtype=torch.bool, device=device)
    centre_row, centre_col = h // 2, w // 2
    mask[0, centre_row, :] = True   # horizontal stripe, first channel only
    mask[:, :, centre_col] = True   # vertical stripe, every channel
    return mask


# ------------------------------------------------------------
# split an image into core vs spurious pixels
# ------------------------------------------------------------
def split_core_spur(img, spur_mask):
    """
    Split an image's pixels into core and spurious sets.

    img: (C, H, W) float tensor
    spur_mask: (C, H, W) bool tensor (moved to ``img``'s device)
    returns: (core, spur), each an (N, C) tensor of pixel rows in row-major
             order. A pixel counts as spurious if ANY of its channels is masked.
    """
    channels, _, _ = img.shape

    pixels = img.permute(1, 2, 0).reshape(-1, channels)  # (H*W, C)
    pixel_is_spur = (
        spur_mask.to(img.device).permute(1, 2, 0).reshape(-1, channels).any(dim=1)
    )

    return pixels[~pixel_is_spur], pixels[pixel_is_spur]

# LPIPS is only needed by augmentation_gap(mode="lpips"). Import it optionally so
# this utility module stays importable on machines without the package.
try:
    import lpips
except ImportError:
    lpips = None

# Load the LPIPS model once at import time (AlexNet backbone). It goes to the GPU
# when one is available and otherwise stays on the CPU instead of failing import.
# It is None when lpips is not installed.
LPIPS_MODEL = None
if lpips is not None:
    LPIPS_MODEL = lpips.LPIPS(net='alex').eval()
    if torch.cuda.is_available():
        LPIPS_MODEL = LPIPS_MODEL.cuda()


def augmentation_gap(original, augmented, spur_mask, mode="cosine", eps=1e-6):
    """
    Ratio of how much an augmentation changes core pixels vs spurious pixels.

    original, augmented : (C, H, W) image before / after the transform
    spur_mask           : (C, H, W) bool mask of spurious pixels (see make_spur_mask)
    mode                : "mse"    -> RMSE over each pixel set
                          "cosine" -> 1 - mean per-pixel cosine similarity
                          "lpips"  -> LPIPS on the whole image divided by a fixed
                                      1e-3 (no core/spur split)
    returns d_core / (d_spur + eps) as a tensor.

    NOTE(review): inputs are passed to LPIPS without rescaling. LPIPS expects
    values in [-1, 1], so confirm the caller's value range.
    """
    # Both splits run first, for every mode.
    core_before, spur_before = split_core_spur(original, spur_mask)
    core_after, spur_after = split_core_spur(augmented, spur_mask)

    if mode == "mse":
        d_core = F.mse_loss(core_before, core_after, reduction='mean').sqrt()
        d_spur = F.mse_loss(spur_before, spur_after, reduction='mean').sqrt()
    elif mode == "cosine":
        d_core = 1.0 - F.cosine_similarity(core_before, core_after, dim=1).mean()
        d_spur = 1.0 - F.cosine_similarity(spur_before, spur_after, dim=1).mean()
    elif mode == "lpips":
        # LPIPS wants a batch dimension; add one to single (C, H, W) images.
        batch_before = original.unsqueeze(0) if original.ndim == 3 else original
        batch_after = augmented.unsqueeze(0) if augmented.ndim == 3 else augmented
        d_core = LPIPS_MODEL(batch_before, batch_after)
        # Fixed tiny denominator so the ratio stays finite.
        d_spur = torch.tensor(1e-3, device=d_core.device)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    return d_core / (d_spur + eps)

def plot_metric_vs_sv(
        x,
        u_avg, avg,
        u_best=None, u_worst=None,
        best=None,   worst=None,
        ylabel='', title='', filename='',
        out_dir="plots__uniform_val_sd"):
    """
    Plot the average metric and, if provided, the best- and worst-group
    metrics for each split. The band between best and worst is shaded.

    Args:
        x: Number of singular values retained (x axis).
        u_avg, u_best, u_worst: "Uniform" split curves; best/worst are optional.
        avg, best, worst: Second split's curves; best/worst are optional.
        ylabel, title: Axis label and title.
        filename: File name inside ``out_dir``.
        out_dir: Output directory (created if missing).
    """
    plt.figure(figsize=(8, 5))

    plt.plot(x, u_avg,  marker='o', label='Uniform avg')
    if u_best is not None and u_worst is not None:
        plt.plot(x, u_best,  linestyle='--', marker='^', label='Uniform best')
        plt.plot(x, u_worst, linestyle='--', marker='v', label='Uniform worst')
        lo = np.minimum(u_best, u_worst)
        hi = np.maximum(u_best, u_worst)
        plt.fill_between(x, lo, hi, alpha=0.15)

    plt.plot(x, avg,    marker='x', label='        Avg')
    if best is not None and worst is not None:
        plt.plot(x, best,  linestyle='-.', marker='^', label='        Best')
        plt.plot(x, worst, linestyle='-.', marker='v', label='        Worst')
        lo = np.minimum(best, worst)
        hi = np.maximum(best, worst)
        plt.fill_between(x, lo, hi, alpha=0.15)

    plt.xlabel('Number of Singular Values Retained')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    # Create the directory we actually save into; the two paths used to differ.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, filename)
    plt.savefig(out_path)
    plt.close()  # release the figure so repeated calls don't accumulate memory
    print(f"[INFO] Plot saved to {out_path}")

def sample_index(x, sample_size, replacement=False):
    """Return a LongTensor of ``sample_size`` indices into ``x``.

    With ``replacement`` the indices are drawn uniformly from
    ``[0, len(x))``; without it they are distinct (a prefix of a random
    permutation) and ``sample_size`` must not exceed ``len(x)``.
    """
    population = len(x)
    if replacement:
        return torch.randint(0, population, size=(sample_size,))
    assert sample_size <= population
    return torch.randperm(population)[:sample_size]


def sample(x, sample_size, replacement=False):
    """Draw ``sample_size`` rows from ``x`` (indexed along dim 0).

    Args:
        x: tensor to sample rows from.
        sample_size: number of rows to return.
        replacement: ``True`` samples with replacement, ``False`` without.
            ``"auto"`` samples with replacement only when ``sample_size``
            exceeds ``len(x)``.
    """
    if replacement == "auto":
        replacement = sample_size > len(x)
    # Equality (not identity) so values equal to True, e.g. 1, keep working.
    if replacement == True:
        return x[torch.randint(0, len(x), size=(sample_size,))]
    assert sample_size <= len(x)
    # Slice the permutation first so only ``sample_size`` rows are gathered,
    # instead of permuting (copying) all of ``x`` and then slicing.
    return x[torch.randperm(len(x))[:sample_size]]


def generate_random_func(n, k, F2=True):
    """Build a random Boolean function of ``k`` randomly chosen inputs out of ``n``.

    Returns:
        ``(func, subset)``: ``subset`` is the tuple of chosen input indices,
        ``func`` a ``RandomFunc`` backed by a random ``2**k``-row truth table
        drawn with ``generate_random_x``.
    """
    truth_table = generate_random_x(1, sample_num=2 ** k, F_2=F2)
    chosen = tuple(random.sample(range(n), k=k))
    return RandomFunc(truth_table, chosen, F2=F2), chosen

def generate_random_func_with_subset(subset, F2=True):
    """Build a ``RandomFunc`` over the given input ``subset`` with a random truth table."""
    table_rows = 2 ** len(subset)
    truth_table = generate_random_x(1, sample_num=table_rows, F_2=F2)
    return RandomFunc(truth_table, subset, F2=F2)

class BooleanFunc(object):
    """Marker base class for the callable Boolean/polynomial functions in this module."""


class RandomFunc(BooleanFunc):
    """Boolean function given by a lookup table over a subset of input columns.

    The input values at columns ``subset`` are read as a little-endian binary
    number (``subset[0]`` is the least significant bit), which selects a row
    of ``mapping``.

    Args:
        mapping: lookup table with ``2**len(subset)`` rows
            (``generate_random_x(1, ...)`` yields shape ``(2**k, 1)``).
        subset: sequence of input column indices the function reads.
        F2: if ``False``, input value ``-1`` is treated as bit 0 (for
            ``{-1, 1}``-encoded inputs).
        device: device the table lives on and outputs are returned on.
            Defaults to ``"cuda"``.
    """
    def __init__(self, mapping, subset, F2=True, device="cuda"):
        self.F2 = F2
        self.device = device
        self.mapping = mapping.to(device)
        self.subset = subset

    def __call__(self, x_s):
        if not self.F2:
            x_s = torch.where(x_s == -1, 0, x_s)
        indices = torch.zeros(x_s.shape[0], dtype=torch.long, device=self.device)
        for c, i in enumerate(self.subset):
            indices += (x_s[:, i] * 2**c).type(torch.long).to(self.device)
        out = self.mapping[indices]
        # Drop only the trailing singleton column. A bare ``squeeze()`` would
        # turn a single-row batch into a 0-d tensor, which then breaks
        # ``torch.concat`` when results are gathered batch by batch.
        if out.dim() > 1 and out.shape[-1] == 1:
            out = out.squeeze(-1)
        return out

def generate_bit_vectors(n):
    """Enumerate all ``2**n`` bit vectors of length ``n``.

    Rows are in lexicographic order (last bit varies fastest, as in
    ``itertools.product``). The result always has shape ``(2**n, n)`` and
    dtype ``torch.int64``.
    """
    if n == 0:
        # The empty product has exactly one element: the empty vector.
        return torch.zeros((1, 0), dtype=torch.long)
    bit_values = torch.tensor([0, 1])
    # torch.cartesian_prod returns a 1-D tensor when given a single input,
    # so reshape to keep the (2**n, n) layout for n == 1 as well.
    return torch.cartesian_prod(*[bit_values] * n).reshape(-1, n)
    
def generate_random_x(length, sample_num=None, F_2=True, unique=False):
    """Return a float32 matrix of Boolean input vectors with ``length`` columns.

    If ``sample_num`` is truthy, draw that many rows uniformly at random
    (with ``unique``, rows are de-duplicated by ``torch.unique``, which also
    sorts them). Otherwise enumerate all ``2**length`` vectors in
    lexicographic order. With ``F_2=False`` the values 0 become -1, i.e. the
    ``{-1, 1}`` encoding.
    """
    if not sample_num:
        rows = list(itertools.product([0, 1], repeat=length))
        x = torch.tensor(rows, dtype=torch.float32)
    else:
        x = torch.randint(low=0, high=2, size=(sample_num, length), dtype=torch.float32)
        if unique:
            x = torch.unique(x, dim=0)
    if not F_2:
        x[x == 0] = -1
    return x

def majority(x_s, subset=()):
    """Row-wise ``sign`` of the sum of the columns ``subset`` of ``x_s``.

    For ``{-1, 1}``-encoded inputs this is the majority vote of the selected
    bits. An odd-sized ``subset`` is required (asserted) so such a vote cannot
    tie. NOTE(review): with ``{0, 1}`` inputs the result is 1 whenever any
    selected bit is set, not a majority. Presumably callers use the
    ``{-1, 1}`` encoding; confirm.

    ``subset`` defaults to an immutable empty tuple rather than a shared
    list. An empty subset still fails the odd-length assertion.
    """
    assert len(subset) % 2 == 1
    return torch.sign(torch.sum(x_s[:, subset], dim=1))

def generate_majority_func(subset):
    """Return ``majority`` with its ``subset`` keyword bound (a ``functools.partial``)."""
    vote_on_subset = partial(majority, subset=subset)
    return vote_on_subset

def generate_random_majority(feature_len, subset_size):
    """Majority over ``subset_size`` distinct, randomly chosen input columns.

    Returns ``(func, [subset])``, where ``subset`` is the chosen index tuple.
    """
    assert 0 < subset_size <= feature_len
    chosen = tuple(random.sample(range(feature_len), k=subset_size))
    func = generate_majority_func(chosen)
    return func, [chosen]

def generate_parity_func(subset):
    """Return ``parity`` with its ``subset`` keyword bound (a ``functools.partial``)."""
    chi_subset = partial(parity, subset=subset)
    return chi_subset

def generate_fixed_parity_func(func_degree):
    """Parity over the first ``func_degree`` input columns.

    Returns ``(func, [subset])`` with ``subset == (0, 1, ..., func_degree - 1)``.
    """
    leading = tuple(range(func_degree))
    chi_leading = partial(parity, subset=leading)
    return chi_leading, [leading]

def generate_random_parity_func(feature_len, subset_size):
    """Parity over ``subset_size`` distinct, randomly chosen input columns.

    Returns ``(func, [subset])``, where ``subset`` is the chosen index tuple.
    """
    assert 0 < subset_size <= feature_len
    chosen = tuple(random.sample(range(feature_len), k=subset_size))
    func = generate_parity_func(chosen)
    return func, [chosen]

def get_sample_space(model, feature_len, sample_num=None, batch_size=64, F_2=True,
                     unique=False):
    """Draw (or enumerate) inputs and evaluate ``model`` on them.

    Returns:
        ``(feature_space, y_s)``: the input matrix from ``generate_random_x``
        and the batched outputs of ``batch_forward`` on it.
    """
    inputs = generate_random_x(feature_len, sample_num=sample_num, F_2=F_2, unique=unique)
    outputs = batch_forward(model, inputs, batch_size=batch_size, F_2=F_2)
    return inputs, outputs

def batch_forward(model, x, batch_size=64, F_2=True, R=False, device="cpu"):
    """Evaluate ``model`` on ``x`` in batches.

    ``torch.nn.Module`` instances go through ``model_batch_forward``, which
    honours ``F_2`` and ``R``. Any other callable goes through
    ``func_batch_forward``, where ``F_2`` and ``R`` are not used.
    """
    if not isinstance(model, torch.nn.Module):
        return func_batch_forward(model, x, batch_size=batch_size, device=device)
    return model_batch_forward(model, x, batch_size=batch_size, F_2=F_2, R=R, device=device)


def model_batch_forward(model, x, batch_size=64, F_2=True, R=False, device="cpu"):
    """Run ``model`` over ``x`` batch by batch, in eval mode and without grad.

    Side effects: moves ``model`` to ``device`` and switches it to eval mode.

    Per-batch output:
      * ``R=True`` or ``F_2=True``: the raw model output;
      * ``R=False`` and ``F_2=False``: ``argmax`` over dim 1, with class
        index 0 replaced by -1.
    The per-batch outputs are concatenated along dim 0.
    """
    model = model.to(device)
    loader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=False)
    model.eval()
    to_signed_label = (not R) and (not F_2)
    outputs = []
    with torch.no_grad():
        for (batch,) in loader:
            out = model(batch.to(device))
            if to_signed_label:
                out = out.argmax(dim=1)
                out[out == 0] = -1
            outputs.append(out)
    return torch.concat(outputs)


def func_batch_forward(func, x, batch_size=64, device="cpu"):
    """Apply a plain callable ``func`` to ``x`` batch by batch, without grad.

    Each batch is moved to ``device`` before the call, and the per-batch
    results are concatenated along dim 0.
    """
    loader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        pieces = [func(batch.to(device)) for (batch,) in loader]
    return torch.concat(pieces)

def sample_bit_vectors(n, k):
    """
    Sample ``n`` random bit vectors of length ``k``.

    Every bit is drawn independently and uniformly from {0, 1}. 0s and 1s
    are equally likely in expectation, but their counts are not forced to
    balance.

    Arguments:
    - n: The number of bit vectors to sample.
    - k: The length of each bit vector.

    Returns:
    - An int64 tensor of shape (n, k).
    """
    return torch.randint(low=0, high=2, size=(n, k))

def fourier_weight_estimate(model, subset, n, est_num = 100, F_2=True, R=False, device="cpu"):
    """
    Monte-Carlo estimate of the Fourier coefficient of ``model`` on ``subset``.

    Draws ``est_num`` random inputs ``x`` and returns the sample mean of
    ``model(x) * chi_S(x)``, where ``chi_S`` is ``parity`` over ``subset``.

    model: any model that supports model(tensor) = {-1,1}.
    subset: indices of the parity's variables, each in [0, n-1]. An empty
        subset or None gives the weight of the empty set (chi = 1).
    n: number of input variables.
    est_num: number of random samples.
    F_2, R: forwarded to ``generate_random_x`` / ``batch_forward``.
    device: device the model is evaluated on and the result lives on.

    Returns a 0-d float tensor on ``device``.
    """
    x_s = generate_random_x(n, est_num, F_2=F_2)
    if not subset:
        parity_S = torch.ones(est_num)
    else:
        assert all(0 <= i < n for i in subset)
        chi_S = generate_parity_func(subset)
        parity_S = func_batch_forward(chi_S, x_s)
    parity_S = parity_S.type(torch.float).to(device)
    # Forward ``device`` so the model runs where ``parity_S`` lives; the
    # trailing ``.to(device)`` also covers plain callables that return on
    # another device. ``reshape(-1)`` (rather than ``squeeze()``) keeps a
    # 1-D vector even for est_num == 1, so the dot product stays valid.
    pred_y = batch_forward(model, x_s, F_2=F_2, R=R, device=device)
    pred_y = pred_y.type(torch.float).reshape(-1).to(device)
    return (pred_y @ parity_S) / est_num


def fourier_weight(model, subset, n, F_2=True, R=False, device="cpu"):
    """
    Exact Fourier coefficient of ``model`` on ``subset``, over all 2**n inputs.

    Enumerates every input ``x`` and returns the mean of
    ``model(x) * chi_S(x)``, where ``chi_S`` is ``parity`` over ``subset``.

    model: any model that supports model(tensor) = {-1,1}.
    subset: indices of the parity's variables, each in [0, n-1]. An empty
        subset or None gives the weight of the empty set (chi = 1).
    n: number of input variables (2**n inputs are evaluated).
    F_2, R: forwarded to ``generate_random_x`` / ``batch_forward``.
    device: device the model is evaluated on and the result lives on.

    Returns a 0-d float tensor on ``device``.
    """
    x_s = generate_random_x(n, None, F_2=F_2)
    num_inputs = 2 ** n
    if not subset:
        parity_S = torch.ones(num_inputs)
    else:
        assert all(0 <= i < n for i in subset)
        chi_S = generate_parity_func(subset)
        parity_S = func_batch_forward(chi_S, x_s)
    parity_S = parity_S.type(torch.float).to(device)
    # Forward ``device`` so the model runs where ``parity_S`` lives; the
    # trailing ``.to(device)`` also covers plain callables that return on
    # another device. ``reshape(-1)`` keeps a 1-D vector for n == 0.
    pred_y = batch_forward(model, x_s, F_2=F_2, R=R, device=device)
    pred_y = pred_y.type(torch.float).reshape(-1).to(device)
    return (pred_y @ parity_S) / num_inputs

def correlation_estimate(model, func, n, est_num = 100, F_2=True, starting_index=0, ending_index=None, device="cpu"):
    """
    Monte-Carlo estimate of E[model(x) * func(x[:, starting_index:ending_index])].

    model: any model that supports model(tensor) = {-1,1}; sees all ``n`` columns.
    func: target function, evaluated only on the column slice
        ``x[:, starting_index:ending_index]``.
    n: number of input variables.
    est_num: number of random samples.
    F_2: input encoding, forwarded to ``generate_random_x`` / ``batch_forward``.
    starting_index, ending_index: column slice passed to ``func``;
        ``ending_index=None`` means ``n``.
    device: device the model is evaluated on.

    Returns the estimate as a Python float.
    """
    if ending_index is None:
        ending_index = n
    x_s = generate_random_x(n, est_num, F_2=F_2)
    func_y = func(x_s[:, starting_index:ending_index]).type(torch.float).to(device)
    pred_y = batch_forward(model, x_s, F_2=F_2, device=device)
    # ``.to(device)``: a plain-callable ``model`` may return on another device.
    # ``reshape(-1)`` keeps a 1-D vector even when est_num == 1.
    pred_y = pred_y.type(torch.float).reshape(-1).to(device)
    return ((pred_y @ func_y) / est_num).cpu().item()

def fourier_dict_records_to_df(fourier_dicts):
    """Collect Fourier-coefficient records (anything ``pd.DataFrame`` accepts) into a DataFrame."""
    return pd.DataFrame(fourier_dicts)

    
from itertools import chain, combinations

def powerset(iterable, without_emptyset = False):
    """Lazily yield all subsets of ``iterable`` as tuples, by increasing size.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    With ``without_emptyset=True`` the leading ``()`` is omitted.
    """
    items = list(iterable)
    # Starting the sizes at 1 drops exactly the single empty tuple.
    smallest = 1 if without_emptyset else 0
    return chain.from_iterable(
        combinations(items, size) for size in range(smallest, len(items) + 1)
    )

def estimate_fourier_expansion(model, n, sample_num = 100, F_2=True, targets=(), R=False, device="cpu"):
    """Monte-Carlo estimates of ``model``'s Fourier coefficients.

    targets: iterable of index subsets to estimate. When empty or None, every
        subset of ``range(n)`` (including the empty set) is used, i.e. 2**n
        entries. Defaults to an immutable empty tuple rather than a shared list.
    Other arguments are forwarded to ``fourier_weight_estimate``.

    Returns a dict mapping ``tuple(subset)`` to its estimate (Python float).
    """
    if not targets:
        targets = powerset(range(n))
    fourier_dict = {}
    for subset in targets:
        est = fourier_weight_estimate(model, subset, n, sample_num, F_2=F_2, R=R, device=device)
        fourier_dict[tuple(subset)] = est.cpu().item()
    return fourier_dict




def parity(x_s, subset=()):
    """Row-wise product of the columns ``subset`` of ``x_s`` (the character chi_S).

    For ``{-1, 1}``-encoded inputs this is the parity of the selected bits.
    An empty ``subset`` gives the constant-one vector. It is created on
    ``x_s``'s device, when ``x_s`` has one, so both branches return on the
    same device. ``subset`` defaults to an immutable tuple rather than a
    shared list.
    """
    if not subset:
        return torch.ones(len(x_s), device=getattr(x_s, "device", None))
    return torch.prod(x_s[:, subset], dim=1)
    


from sklearn.linear_model import LogisticRegression
def estimate_decoded_fourier_expansion(model, n, sample_num, targets, F_2=True):
    """Linear-probe decodability of each parity in ``targets`` from ``model`` embeddings.

    targets: iterable of index subsets. When empty or None, every subset of
        ``range(n)`` is used.
    Returns a dict mapping ``tuple(subset)`` to the probe's held-out accuracy
    (as returned by ``fourier_decoded_weight_estimate``), as a Python float.
    """
    if not targets:
        targets = powerset(range(n))
    fourier_dict = {}
    for subset in targets:
        score = fourier_decoded_weight_estimate(model, subset, n, sample_num, F_2=F_2)
        # ``LogisticRegression.score`` may return a plain Python float (no
        # ``.item()``) or a NumPy scalar depending on the scikit-learn
        # version; ``float()`` accepts both.
        fourier_dict[tuple(subset)] = float(score)
    return fourier_dict

def fourier_decoded_weight_estimate(model, subset, n, est_num = 1000, F_2=True, unique=False):
    """Held-out accuracy of a logistic-regression probe decoding chi_S from embeddings.

    Draws ``est_num`` random inputs over ``n`` variables (fewer if ``unique``
    removes duplicates). Each row is labelled with ``parity`` over ``subset``,
    or with all ones when ``subset`` is empty. The rows are embedded with
    ``model.embedding`` (via ``batch_forward_embedding``), a
    ``LogisticRegression`` is fit on the first 80% of rows, and it is scored
    on the remaining rows.

    Returns the probe's held-out accuracy (a float).
    """
    train_propotion = 0.8
    x_s = generate_random_x(n, est_num, unique=unique, F_2=F_2)
    # ``unique=True`` can drop duplicate rows, so size labels and the split
    # from the actual row count rather than ``est_num``.
    # NOTE(review): torch.unique also sorts rows, so with unique=True the
    # train/test split is not random; confirm that is acceptable.
    num_rows = len(x_s)
    break_index = round(num_rows * train_propotion)
    if not subset:
        parity_S = torch.ones(num_rows)
    else:
        assert all(0 <= i < n for i in subset)
        chi_S = generate_parity_func(subset)
        parity_S = func_batch_forward(chi_S, x_s)
    parity_S = parity_S.type(torch.float).cpu()
    embeddings = batch_forward_embedding(model, x_s).cpu()
    training_x, training_y = embeddings[:break_index], parity_S[:break_index]
    testing_x, testing_y = embeddings[break_index:], parity_S[break_index:]
    classes = torch.unique(training_y)
    if len(classes) == 1:
        # LogisticRegression.fit raises on single-class targets (always the
        # case for the empty subset). Score the constant predictor instead.
        return (testing_y == classes[0]).float().mean().item()
    lr = LogisticRegression()
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        lr.fit(training_x, training_y)
    return lr.score(testing_x, testing_y)

def generate_staircase_func(degree, device="cpu"):
    """Staircase function: ``PolynomialThreshold`` over the prefix monomials.

    Uses the index prefixes ``(0,), (0, 1), ..., (0, ..., degree-1)``, each
    with coefficient 1. Returns ``(func, parity_subsets)``.
    """
    parity_subsets = [tuple(range(length)) for length in range(1, degree + 1)]
    unit_coefs = [1] * len(parity_subsets)
    func = PolynomialThreshold(parity_subsets, unit_coefs, device=device)
    return func, parity_subsets

def batch_forward_embedding(model, x, batch_size=64, device="cpu"):
    """Compute ``model.embedding`` for every row of ``x``, batch by batch.

    Moves ``model`` to ``device``, puts it in eval mode and runs under
    ``torch.no_grad()``. Returns the per-batch embeddings concatenated on dim 0.
    """
    model = model.to(device)
    loader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=False)
    model.eval()
    with torch.no_grad():
        chunks = [model.embedding(batch.to(device)) for (batch,) in loader]
    return torch.concat(chunks)

def decoded_accuracy_on_func(model, func, n, starting_index=0, ending_index=None, est_num = 1000, F_2=True, unique=False, device="cpu"):
    """Held-out accuracy of a logistic-regression probe decoding ``func`` from embeddings.

    Draws ``est_num`` random inputs over ``n`` variables (fewer if ``unique``
    removes duplicates). Each row is labelled with ``func`` applied to columns
    ``starting_index:ending_index`` (``ending_index=None`` means ``n``). The
    rows are embedded with ``model.embedding`` on ``device``, a
    ``LogisticRegression`` is fit on the first 80% of rows, and it is scored
    on the rest.

    Returns the probe's held-out accuracy (a float).
    """
    if ending_index is None:
        ending_index = n
    train_propotion = 0.8
    x_s = generate_random_x(n, est_num, unique=unique, F_2=F_2).to(device)
    # ``unique=True`` can drop duplicate rows, so split by the actual count.
    # NOTE(review): torch.unique also sorts rows, so with unique=True the
    # split is not random; confirm that is acceptable.
    num_rows = len(x_s)
    break_index = round(num_rows * train_propotion)
    y_s = func(x_s[:, starting_index:ending_index]).type(torch.float).cpu()
    embeddings = batch_forward_embedding(model, x_s, device=device).cpu()
    training_x, training_y = embeddings[:break_index], y_s[:break_index]
    testing_x, testing_y = embeddings[break_index:], y_s[break_index:]
    classes = torch.unique(training_y)
    if len(classes) == 1:
        # LogisticRegression.fit raises on single-class targets (e.g. a
        # constant ``func``). Score the constant predictor instead.
        return (testing_y == classes[0]).float().mean().item()
    lr = LogisticRegression()
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        lr.fit(training_x, training_y)
    return lr.score(testing_x, testing_y)


def decoded_accuracy_on_spurious_dataset(model, dataloader, est_num=None, train_propotion=0.8):
    """Linear-probe accuracy for the core and the spurious label from ``model`` embeddings.

    Embeds examples from ``dataloader`` (``est_num`` defaults to, and is capped
    at, the dataset size; a warning is printed when capping). For each label,
    a ``LogisticRegression`` is fit on the first ``train_propotion`` fraction
    of ``est_num`` rows and scored on the remaining rows up to ``est_num``.

    Returns ``{"core": accuracy, "spurious": accuracy}``.
    """
    max_size = len(dataloader.dataset)
    if est_num is None:
        est_num = max_size
    elif est_num > max_size:
        print(f"Warning: required est_num {est_num} > dataset size {max_size}. And it has been set to the dataset size")
        est_num = max_size

    embedding_s, core_y_s, _, spurious_y_s = batch_forward_embedding_spurious_dataset(model, dataloader, sample_size=est_num)
    embedding_s = embedding_s.cpu()
    core_y_s = core_y_s.cpu().squeeze()
    spurious_y_s = spurious_y_s.cpu().squeeze()
    split = round(est_num * train_propotion)
    train_x, test_x = embedding_s[:split], embedding_s[split:est_num]

    res_dict = dict()
    lr = LogisticRegression()
    for key, labels in (("core", core_y_s), ("spurious", spurious_y_s)):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            lr.fit(train_x, labels[:split])
        res_dict[key] = lr.score(test_x, labels[split:est_num])
    return res_dict

def accuracy_to_correlation(accuracy):
    """Map accuracy ``a`` in [0, 1] to correlation ``a - (1 - a)`` (= 2a - 1) in [-1, 1].

    Accepts a scalar, sequence, NumPy array or tensor, and returns a tensor.
    ``torch.as_tensor`` is used because ``torch.tensor`` copies, and warns,
    when handed an existing tensor.
    """
    accuracy = torch.as_tensor(accuracy)
    return accuracy - (1 - accuracy)

def batch_forward_prediction_spurious_dataset(model, dataloader, sample_size=None, required_grad=False, device="cuda"):
    """Argmax predictions of ``model`` on batches from a spurious-correlation loader.

    ``dataloader`` must yield ``(x, core_y, G, spurious_y)`` batches. Whole
    batches are read until at least ``sample_size`` examples have been seen
    (``sample_size`` defaults to the dataset size), so the outputs may be
    longer than ``sample_size``.

    required_grad: unused; kept for interface compatibility.
    device: where the model runs. Defaults to ``"cuda"``.

    Returns ``(pred_s, core_y_s, group_y_s, spurious_y_s)``, each concatenated
    on dim 0. Predictions are on ``device``; labels are as the loader yields them.
    """
    model = model.to(device)
    model.eval()
    max_size = len(dataloader.dataset)
    if not sample_size:
        sample_size = max_size
    if max_size < sample_size:
        # Report the dataset size (nothing has been read yet at this point).
        print(f"Warning: sample_size: {sample_size} greater than total_size: {max_size}")
    pred_s, core_y_s, group_y_s, spurious_y_s = [], [], [], []
    read_size = 0
    with torch.no_grad():
        for x, core_y, G, spurious_y in dataloader:
            if read_size >= sample_size:
                break
            pred_s.append(torch.argmax(model(x.to(device)), dim=1))
            core_y_s.append(core_y)
            group_y_s.append(G)
            spurious_y_s.append(spurious_y)
            read_size += len(core_y)
    return (torch.concat(pred_s), torch.concat(core_y_s),
            torch.concat(group_y_s), torch.concat(spurious_y_s))

def batch_forward_logits_spurious_dataset(model, dataloader, sample_size=None, required_grad=False, device="cuda"):
    """Raw model outputs on batches from a spurious-correlation loader.

    ``dataloader`` must yield ``(x, core_y, G, spurious_y)`` batches. Whole
    batches are read until at least ``sample_size`` examples have been seen
    (``sample_size`` defaults to the dataset size), so the outputs may be
    longer than ``sample_size``.

    required_grad: unused; kept for interface compatibility.
    device: where the model runs. Defaults to ``"cuda"``.

    Returns ``(pred_s, core_y_s, group_y_s, spurious_y_s)``, each concatenated
    on dim 0, where ``pred_s`` holds the unreduced ``model(x)`` outputs.
    """
    model = model.to(device)
    model.eval()
    max_size = len(dataloader.dataset)
    if not sample_size:
        sample_size = max_size
    if max_size < sample_size:
        # Report the dataset size (nothing has been read yet at this point).
        print(f"Warning: sample_size: {sample_size} greater than total_size: {max_size}")
    pred_s, core_y_s, group_y_s, spurious_y_s = [], [], [], []
    read_size = 0
    with torch.no_grad():
        for x, core_y, G, spurious_y in dataloader:
            if read_size >= sample_size:
                break
            pred_s.append(model(x.to(device)))
            core_y_s.append(core_y)
            group_y_s.append(G)
            spurious_y_s.append(spurious_y)
            read_size += len(core_y)
    return (torch.concat(pred_s), torch.concat(core_y_s),
            torch.concat(group_y_s), torch.concat(spurious_y_s))


def model_correlation_on_spurious_dataset(model, dataloader, est_num=None, train_propotion=0.8):
    """Accuracy and correlation of ``model``'s predictions with the core and spurious labels.

    Evaluates the first ``est_num`` examples (default and cap: the dataset
    size; a warning is printed when capping). Correlation is computed as
    ``2 * accuracy - 1``. ``train_propotion`` is not used here.

    Returns a dict with keys ``core_accuracy``, ``spurious_accuracy``,
    ``core_correlation`` and ``spurious_correlation``.
    """
    max_size = len(dataloader.dataset)
    if est_num is None:
        est_num = max_size
    elif est_num > max_size:
        print(f"Warning: required est_num {est_num} > dataset size {max_size}. And it has been set to the dataset size")
        est_num = max_size

    preds, core_y_s, _, spurious_y_s = batch_forward_prediction_spurious_dataset(model, dataloader, sample_size=est_num)
    preds = preds.cpu()[:est_num]
    core_y_s = core_y_s.cpu().squeeze()[:est_num]
    spurious_y_s = spurious_y_s.cpu().squeeze()[:est_num]

    core_acc = torch.eq(preds, core_y_s).float().mean().item()
    spur_acc = torch.eq(preds, spurious_y_s).float().mean().item()
    return {
        "core_accuracy": core_acc,
        "spurious_accuracy": spur_acc,
        "core_correlation": 2 * core_acc - 1,
        "spurious_correlation": 2 * spur_acc - 1,
    }
    
def batch_forward_embedding_spurious_dataset(model, dataloader, sample_size=None, device="cuda"):
    """``model.embedding`` outputs on batches from a spurious-correlation loader.

    ``dataloader`` must yield ``(x, core_y, G, spurious_y)`` batches. Whole
    batches are read until at least ``sample_size`` examples have been seen
    (``sample_size`` defaults to the dataset size), so the outputs may be
    longer than ``sample_size``.

    device: where the model runs. Defaults to ``"cuda"``. Embeddings are
        moved back to CPU per batch.

    Returns ``(embedding_s, core_y_s, group_y_s, spurious_y_s)``, each
    concatenated on dim 0.
    """
    model = model.to(device)
    model.eval()
    max_size = len(dataloader.dataset)
    if not sample_size:
        sample_size = max_size
    if max_size < sample_size:
        # Report the dataset size (nothing has been read yet at this point).
        print(f"Warning: sample_size: {sample_size} greater than total_size: {max_size}")
    embedding_s, core_y_s, group_y_s, spurious_y_s = [], [], [], []
    read_size = 0
    with torch.no_grad():
        for x, core_y, G, spurious_y in dataloader:
            if read_size >= sample_size:
                break
            embedding_s.append(model.embedding(x.to(device)).cpu())
            core_y_s.append(core_y)
            group_y_s.append(G)
            spurious_y_s.append(spurious_y)
            read_size += len(core_y)
    return (torch.concat(embedding_s), torch.concat(core_y_s),
            torch.concat(group_y_s), torch.concat(spurious_y_s))

def generate_random_polynomial(n, k_s):
    """Random ``PolynomialThreshold`` over ``n`` inputs.

    ``k_s[d - 1]`` is the number of distinct degree-``d`` monomials drawn
    (without replacement). Each monomial gets a coefficient uniform in [-1, 1).
    """
    assert 0 < len(k_s) <= n
    subsets_all = []
    coefs_all = []
    for degree, count in enumerate(k_s, start=1):
        candidates = list(itertools.combinations(range(n), degree))
        subsets_all.extend(random.sample(candidates, k=count))
        coefs_all.extend((2 * torch.rand(count) - 1).tolist())
    return PolynomialThreshold(subsets_all, coefs_all)
def custom_sign(x):
    """Element-wise ``torch.sign`` with zeros mapped to +1."""
    s = torch.sign(x)
    return torch.where(s == 0, torch.ones_like(s), s)
class PolynomialThreshold(BooleanFunc):
    """``custom_sign`` of a real polynomial in the parity basis: sum_S c_S * chi_S(x).

    Args:
        subsets_all: index tuples ``S``, one per monomial.
        coefs_all: coefficients ``c_S``, aligned with ``subsets_all``.
        device: device the sum is accumulated on (and the output lives on).

    Attributes (set in ``__init__``):
        funcs: one ``parity`` partial per subset.
        fourier_dict: ``{S: c_S}`` built from the two lists.
    """
    def __init__(self, subsets_all, coefs_all, device="cpu"):
        self.subsets_all = subsets_all
        self.coefs_all = coefs_all
        self.device = device
        self.funcs = [generate_parity_func(subset) for subset in subsets_all]
        self.fourier_dict = dict(zip(subsets_all, coefs_all))

    def __call__(self, x):
        total = torch.zeros(len(x), device=self.device)
        for chi, coef in zip(self.funcs, self.coefs_all):
            total += coef * chi(x).to(self.device)
        return custom_sign(total)
    
class Majority(BooleanFunc):
    """Row-wise sum of the inputs over dim 1.

    NOTE(review): despite the name, this returns the raw sum, not its sign.
    Presumably callers threshold it themselves; confirm before relying on
    it as a {-1, 1} majority vote.
    """

    def __call__(self, x):
        return x.sum(dim=1)
    
class Polynomial(BooleanFunc):
    """Real-valued polynomial in the parity basis: sum_S c_S * chi_S(x).

    Unlike ``PolynomialThreshold``, the sum is returned without a sign.

    Args:
        subsets_all: index tuples ``S``, one per monomial.
        coefs_all: coefficients ``c_S``, aligned with ``subsets_all``.
        device: device the sum is accumulated on and the output lives on.
            Defaults to ``"cuda"``; pass ``"cpu"`` to evaluate without a GPU
            (mirrors ``PolynomialThreshold``'s ``device`` argument).

    Attributes (set in ``__init__``):
        funcs: one ``parity`` partial per subset.
        fourier_dict: ``{S: c_S}`` built from the two lists.
    """
    def __init__(self, subsets_all, coefs_all, device="cuda"):
        self.subsets_all = subsets_all
        self.coefs_all = coefs_all
        self.device = device
        self.funcs = [generate_parity_func(subset) for subset in subsets_all]
        self.fourier_dict = dict(zip(subsets_all, coefs_all))

    def __call__(self, x):
        res = torch.zeros(len(x), device=self.device)
        for func, coef in zip(self.funcs, self.coefs_all):
            res += coef * func(x).to(self.device)
        return res